# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import numpy as np

from megengine.functional.tensor import zeros

from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
from .serialization import (
    register_functional_loader,
    register_module_loader,
    register_opdef_loader,
    register_tensor_method_loader,
)


"""
# Expr loaders examples

from ..core.ops.builtin import Elemwise

@register_opdef_loader(Elemwise)
def add_opdef_loader(expr):
    if expr.opdef_state["mode"] == "ADD":
        expr.opdef_state["mode"] = "MUL"
        node = expr.inputs[1]
        astype_expr = CallMethod(node, "astype")
        oup = TensorNode(
            astype_expr,
            shape=node.shape,
            dtype=expr.inputs[0].dtype,
            qparams=node.qparams,
        )

        astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
        astype_expr.return_val = (oup,)
        expr.inputs[1] = oup


@register_functional_loader(("megengine.functional.nn", "conv2d"))
def conv2df_loader(expr):
    # expr.func = ("megengine.functional.nn","conv2d")
    kwargs = expr.kwargs
    orig_weight = expr.named_args["weight"]

    astype_expr = CallMethod(orig_weight, "astype")
    oup = TensorNode(
        astype_expr,
        shape=orig_weight.shape,
        dtype=orig_weight.dtype,
        qparams=orig_weight.qparams,
    )

    astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
    astype_expr.return_val = (oup,)

    expr.set_arg("weight", oup)


@register_module_loader(("megengine.module.conv", "Conv2d"))
def conv2dm_loader(expr):
    module = expr.inputs[0].owner
    args = list(expr.args)
    orig_inp = args[1]
    astype_expr = CallMethod(orig_inp, "astype")
    oup = TensorNode(
        astype_expr,
        shape=orig_inp.shape,
        dtype=orig_inp.dtype,
        qparams=orig_inp.qparams,
    )
    astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
    astype_expr.return_val = (oup,)
    args[1] = oup
    expr.set_args_kwargs(*args)


@register_tensor_method_loader("__add__")
def add_loader(expr):
    args = list(expr.args)
    if not isinstance(args[1], TensorNode):
        args[1] = tensor(args[1])
        node = Constant(args[1], "const").outputs[0]

        astype_expr = CallMethod(node, "astype")
        oup = TensorNode(
            astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
        )

        astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
        astype_expr.return_val = (oup,)
        args[1] = oup
        expr.set_args_kwargs(*args)
"""


@register_module_loader(
    ("megengine.module.batchnorm", "BatchNorm1d"),
    ("megengine.module.batchnorm", "BatchNorm2d"),
    ("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
    """Backfill ``param_dim`` on BatchNorm modules traced with mge 1.6.

    Exprs saved by mge 1.6 carry no ``version`` attribute; their BatchNorm
    modules also predate ``param_dim``, so default it to ``"dim_1c11"``.
    """
    # Versioned exprs (mge >= 1.7) need no patching.
    if hasattr(expr, "version"):
        return
    bn_module = expr.inputs[0].owner
    if not hasattr(bn_module, "param_dim"):
        bn_module.param_dim = "dim_1c11"


@register_module_loader(
    ("megengine.module.conv_bn", "ConvBn2d"),
    ("megengine.module.conv_bn", "ConvBnRelu2d"),
    ("megengine.module.qat.conv_bn", "ConvBn2d"),
    ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
    """Patch ConvBn modules from older traces with attributes newer code expects.

    * mge 1.6 traces (expr has no ``version`` attribute): the inner BatchNorm
      lacks ``param_dim`` — default it to ``"dim_1c11"``.
    * traces predating padding modes: the inner conv lacks ``padding_mode`` —
      default it to ``"zeros"``.
    """
    # Hoisted: the original fetched expr.inputs[0].owner twice.
    module = expr.inputs[0].owner
    # mge 1.6
    if not hasattr(expr, "version") and not hasattr(module.bn, "param_dim"):
        module.bn.param_dim = "dim_1c11"
    if not hasattr(module.conv, "padding_mode"):
        module.conv.padding_mode = "zeros"


@register_opdef_loader(BatchNorm)
def bn_opdef_loader(expr):
    """Splice in the missing output on BatchNorm opdefs traced with mge 1.6.

    Newer BatchNorm opdefs have 6 outputs; 1.6 traces (no ``version``
    attribute) have only 5, so a placeholder TensorNode is inserted at
    index 4 to restore the expected layout.
    """
    # mge 1.6 only; versioned or already-6-output exprs are left alone.
    if hasattr(expr, "version") or len(expr.outputs) == 6:
        return
    assert len(expr.outputs) == 5
    last_out = expr.outputs[-1]
    placeholder = TensorNode(
        expr, shape=(0,), dtype=None, qparams=last_out._qparams,
    )
    expr.outputs.insert(4, placeholder)


@register_functional_loader(
    ("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros")
)
def tensor_gen_func_loader(expr):
    """Normalize ``ones``/``zeros`` calls from old traces to keyword style.

    Rewrites the stored call as ``(shape, dtype=..., device=...)`` both for
    exprs saved by mge 1.7.0 (all-positional layout) and for version-less
    mge 1.6 exprs (mixed positional/keyword layout).
    """
    if hasattr(expr, "version") and expr.version == "1.7.0":
        # 1.7.0 stored all three positionally: (shape, dtype, device).
        expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2])
    if not hasattr(expr, "version"):
        # compatible with version 1.6: each argument may be positional or keyword
        args, kwargs = expr.args, expr.kwargs

        def _pick(idx, key, *fallback):
            # Positional wins, then keyword, then the fallback default.
            # ``shape`` passes no fallback, so a missing shape raises
            # KeyError exactly as the original code did.
            if len(args) > idx:
                return args[idx]
            if key in kwargs or not fallback:
                return kwargs[key]
            return fallback[0]

        shape = _pick(0, "shape")
        dtype = _pick(1, "dtype", "float32")
        device = _pick(2, "device", None)
        expr.set_args_kwargs(shape, dtype=dtype, device=device)


@register_functional_loader(("megengine.functional.nn", "pad"))
def pad_func_loader(expr):
    """Rename the misspelled ``pad_witdth`` kwarg from old traces to ``pad_width``."""
    kw = expr.kwargs
    if "pad_witdth" not in kw:
        return
    # NOTE: the misspelling is intentional — old serialized graphs used it.
    kw["pad_width"] = kw.pop("pad_witdth")
    expr.set_args_kwargs(*expr.args, **kw)


@register_module_loader(
    ("megengine.module.conv", "Conv1d"),
    ("megengine.module.conv", "Conv2d"),
    ("megengine.module.conv", "ConvRelu2d"),
    ("megengine.module.qat.conv", "Conv2d"),
    ("megengine.module.qat.conv", "ConvRelu2d"),
    ("megengine.module.quantized.conv", "Conv2d"),
    ("megengine.module.quantized.conv", "ConvRelu2d"),
)
def conv2d_module_loader(expr):
    """Default ``padding_mode`` to ``"zeros"`` on conv modules from old traces."""
    conv_module = expr.inputs[0].owner
    try:
        conv_module.padding_mode
    except AttributeError:
        # Traces predating padding-mode support lack the attribute entirely.
        conv_module.padding_mode = "zeros"


@register_module_loader(
    ("megengine.module.quantized.conv_bn", "ConvBn2d"),
    ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
)
def quantized_convbn2d_module_loader(expr):
    """Default ``padding_mode`` to ``"zeros"`` on quantized ConvBn modules from old traces."""
    owner = expr.inputs[0].owner
    if hasattr(owner, "padding_mode"):
        return
    owner.padding_mode = "zeros"
